import torch
import torch.nn as nn


class PlasticityModel(nn.Module):
    """Von Mises return mapping applied to logarithmic (Hencky) principal strains.

    ``yield_stress`` and ``shear_modulus`` are stored as ``nn.Parameter`` objects so
    they can be optimized with gradient descent through ``forward``.
    """

    def __init__(self, yield_stress: float = 2.16, shear_modulus: float = 28.0):
        """
        Define trainable continuous physical parameters for differentiable optimization.

        The default values are the initial parameter values used when none are supplied.

        Args:
            yield_stress (float): yield stress threshold for plastic flow.
            shear_modulus (float): shear modulus for plastic correction.
        """
        super().__init__()
        self.yield_stress = nn.Parameter(torch.tensor(yield_stress))
        self.shear_modulus = nn.Parameter(torch.tensor(shear_modulus))

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Compute corrected deformation gradient from deformation gradient tensor using von Mises plasticity on
        logarithmic deviatoric principal strains.

        The deviatoric part of the principal log strains is projected radially back onto
        the yield surface of radius ``yield_stress / (2 * shear_modulus)``. The volumetric
        (mean) part and the singular vectors of ``F`` are left unchanged.

        Args:
            F (torch.Tensor): deformation gradient tensor of shape (..., 3, 3). The usual
                batched (B, 3, 3) layout, a single (3, 3) matrix, and additional leading
                batch dimensions are all accepted.

        Returns:
            F_corrected (torch.Tensor): corrected deformation gradient, same shape as ``F``.
        """
        # F = U diag(sigma) Vh; the singular values are the principal stretches.
        U, sigma, Vh = torch.linalg.svd(F)  # U, Vh: (..., 3, 3), sigma: (..., 3)
        # Clamp so the log below stays finite for (nearly) degenerate F.
        sigma = torch.clamp_min(sigma, 1e-6)

        # Principal logarithmic strains, split into volumetric and deviatoric parts.
        epsilon = torch.log(sigma)  # (..., 3)
        epsilon_mean = epsilon.mean(dim=-1, keepdim=True)  # (..., 1)
        epsilon_dev = epsilon - epsilon_mean  # (..., 3)
        # The tiny offset keeps the division below finite when the deviatoric part is zero.
        epsilon_dev_norm = epsilon_dev.norm(dim=-1, keepdim=True) + 1e-12  # (..., 1)

        # Keep the trainable parameters strictly positive to avoid division by zero.
        yield_stress = torch.clamp_min(self.yield_stress, 1e-6)
        shear_modulus = torch.clamp_min(self.shear_modulus, 1e-6)

        # Plastic multiplier: how far the deviatoric strain norm exceeds the yield
        # surface radius; clamped to zero inside the elastic region.
        delta_gamma = torch.clamp_min(
            epsilon_dev_norm - yield_stress / (2 * shear_modulus), 0.0
        )  # (..., 1)

        # Radial return of the deviatoric part. Where delta_gamma == 0 (not yielding)
        # the correction term is zero, so epsilon is kept as-is without a separate mask.
        epsilon_final = epsilon - (delta_gamma / epsilon_dev_norm) * epsilon_dev  # (..., 3)

        # Rebuild F = U diag(exp(epsilon_final)) Vh. Scaling the columns of U by the
        # singular values is equivalent to U @ diag_embed(sigma) but avoids the 3x3 diag.
        sigma_corrected = torch.exp(epsilon_final)  # (..., 3)
        F_corrected = (U * sigma_corrected.unsqueeze(-2)) @ Vh  # (..., 3, 3)

        return F_corrected


class ElasticityModel(nn.Module):
    """St. Venant-Kirchhoff elasticity with trainable, reparameterized moduli.

    Young's modulus is stored as its natural log and Poisson's ratio as a pre-sigmoid
    logit. Any real parameter values therefore map to a positive Young's modulus and
    a Poisson's ratio in (0, 0.49).
    """

    def __init__(self, youngs_modulus_log: float = 11.7, poissons_ratio_logit: float = -0.7):
        """
        Define trainable continuous physical parameters for differentiable optimization.

        The default values are the initial parameter values used when none are supplied.

        Args:
            youngs_modulus_log (float): log of Young's modulus.
            poissons_ratio_logit (float): pre-sigmoid parameter for Poisson's ratio.
        """
        super().__init__()
        self.youngs_modulus_log = nn.Parameter(torch.tensor(youngs_modulus_log))
        self.poissons_ratio_logit = nn.Parameter(torch.tensor(poissons_ratio_logit))

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Compute Kirchhoff stress tensor from deformation gradient tensor using St. Venant-Kirchhoff elasticity.

        tau = F S F^T, where S = 2 * mu * E + lambda * tr(E) * I is the second
        Piola-Kirchhoff stress and E = (F^T F - I) / 2 is the Green-Lagrange strain.

        Args:
            F (torch.Tensor): deformation gradient tensor of shape (..., 3, 3). The usual
                batched (B, 3, 3) layout, a single (3, 3) matrix, and additional leading
                batch dimensions are all accepted.

        Returns:
            kirchhoff_stress (torch.Tensor): Kirchhoff stress tensor, same shape as ``F``.
        """
        # Undo the reparameterizations: exp keeps E > 0; the scaled sigmoid keeps nu in (0, 0.49).
        youngs_modulus = torch.exp(self.youngs_modulus_log)  # scalar
        poissons_ratio = torch.sigmoid(self.poissons_ratio_logit) * 0.49  # scalar

        # Lame parameters.
        mu = youngs_modulus / (2 * (1 + poissons_ratio))  # scalar
        la = youngs_modulus * poissons_ratio / ((1 + poissons_ratio) * (1 - 2 * poissons_ratio))  # scalar

        # A single (3, 3) identity; broadcasting supplies any leading batch dimensions.
        I = torch.eye(3, dtype=F.dtype, device=F.device)

        # Green-Lagrange strain from the right Cauchy-Green tensor C = F^T F.
        Ft = F.transpose(-1, -2)  # (..., 3, 3)
        green_strain = 0.5 * (Ft @ F - I)  # (..., 3, 3)

        # Trace over the last two axes, kept as (..., 1, 1) so it broadcasts against I.
        trace_strain = green_strain.diagonal(dim1=-2, dim2=-1).sum(dim=-1)[..., None, None]

        # Second Piola-Kirchhoff stress S.
        S = 2 * mu * green_strain + la * trace_strain * I  # (..., 3, 3)

        # Kirchhoff stress tau = P F^T with first Piola-Kirchhoff stress P = F S.
        kirchhoff_stress = (F @ S) @ Ft  # (..., 3, 3)

        return kirchhoff_stress
